import numpy as np
import torch
import torch.utils.data
from tqdm import tqdm
from collections import defaultdict
from src.dataset import *
from models.model import *
from src.wasserstein import *
from src.utils import *


def tpr(winv, woutv, level=0.95):
    """Compute the OoD true-positive rate at a fixed in-distribution TNR.

    The decision threshold is the ``level`` quantile of the in-distribution
    scores ``winv`` (NumPy linear interpolation), so roughly a fraction
    ``level`` of InD samples score at or below it. OoD samples scoring
    strictly above the threshold are counted as detected.

    Args:
        winv: In-distribution scores (torch tensor or array-like).
        woutv: Out-of-distribution scores (torch tensor or array-like).
        level: Target TNR, strictly between 0 and 1.

    Returns:
        ``(tpr, threshold)``. ``tpr`` is the fraction of OoD scores above the
        threshold, and ``threshold`` is the quantile returned by
        ``np.quantile``.

    Raises:
        ValueError: If ``level`` is not in the open interval (0, 1).
    """
    if not 0 < level < 1:
        raise ValueError(f"level must be in (0, 1), got {level}")
    # detach() + cpu(): NumPy cannot convert tensors that still carry
    # autograd history or live on an accelerator device.
    win = torch.as_tensor(winv).detach().cpu().numpy()
    wout = torch.as_tensor(woutv).detach().cpu().numpy()
    threshold = np.quantile(win, level)
    print(f"{level*100}% TNR Threshold: {threshold}")
    # Fraction of OoD samples that the threshold fails to separate from InD.
    missed = np.count_nonzero(wout <= threshold) / float(wout.shape[0])
    tpr_value = 1 - missed
    print(f"TPR at {level*100}%  TNR: {tpr_value}")
    return tpr_value, threshold


def loader_wass(data_loader, D):
    """Score every batch of ``data_loader`` with the model ``D``.

    For each ``(img, _)`` batch, the images are moved to ``DEVICE`` and passed
    through ``D``. The outputs are softmaxed over the last dimension and
    scored with ``ood_wass_loss``.

    This function does not switch ``D`` to eval mode; that is left to the
    caller.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` batches.
        D: Callable model applied to each image batch.

    Returns:
        The ``ood_wass_loss`` outputs of all batches concatenated along dim 0
        (presumably one score per sample — confirm against ``ood_wass_loss``).
    """
    wass_dists = []
    # Inference only: without no_grad, every stored batch result would keep
    # its autograd graph alive until the final cat, and the returned tensor
    # would require grad (which NumPy conversion in `tpr` rejects).
    with torch.no_grad():
        for img, _ in tqdm(data_loader):
            logits = D(img.to(DEVICE))
            wass_dists.append(ood_wass_loss(torch.softmax(logits, dim=-1)))
    return torch.cat(wass_dists, dim=0)


def print_stats(stat, name, precision=5):
    """Print ``stat`` followed by its mean, std and mean absolute deviation.

    Args:
        stat: Sequence of numbers to summarize.
        name: Label printed before the raw values.
        precision: Number of decimals the summary values are rounded to.
    """
    print(f"{name}: {stat}")
    values = np.asarray(stat)
    centre = np.mean(values)
    # Mean absolute deviation around the arithmetic mean.
    mad = np.mean(np.abs(centre - values))
    mean_txt = np.round(centre, precision)
    std_txt = np.round(np.std(values), precision)
    mad_txt = np.round(mad, precision)
    print(f"mean: {mean_txt} | std: {std_txt} | MAD: {mad_txt}")


class EVALER():
    """Accumulate OoD-detection statistics over repeated evaluations.

    Each call to :meth:`evaluate` does three things:

    * scores the InD and OoD test loaders with ``loader_wass``;
    * records the per-sample scores;
    * records the TPR (and threshold) at 95% and 99% TNR computed by ``tpr``.

    :meth:`display_stats` summarizes the recorded values across calls.
    """

    # Attributes holding datasets/loaders; they are left out of checkpoints.
    _TRANSIENT_ATTRS = ("xin_t", "xin_v", "xin_v_loader",
                        "xout_v", "xout_v_loader")

    def __init__(self, xin_t, xin_v, xin_v_loader, xout_v, xout_v_loader,
                 n_ood, log_dir, num_classes):
        self.n_ood = n_ood
        self.log_dir = log_dir
        # DATASETS
        self.xin_t = xin_t  # InD training dataset
        self.xin_v = xin_v  # InD testing dataset & loader
        self.xin_v_loader = xin_v_loader
        self.xout_v = xout_v  # OoD testing dataset & loader
        self.xout_v_loader = xout_v_loader
        # METHODOLOGY
        self.num_classes = num_classes
        # Statistics - Wasserstein distance (one entry per evaluate() call)
        self.winv, self.woutv = [], []
        # Statistics - TPR at x% TNR
        self.tpr95, self.tpr99 = [], []
        self.tpr95_thresh, self.tpr99_thresh = [], []

    def save(self, path):
        """Serialize the evaluator with ``torch.save``, without its data.

        The datasets and loaders are cleared on a shallow copy only, so the
        checkpoint holds just the configuration and recorded statistics.
        This instance keeps its data and can continue evaluating.
        """
        import copy  # local import: only this method needs it

        # Shallow copy: shares the statistics lists, but attribute
        # assignments below do not touch `self`.
        snapshot = copy.copy(self)
        for attr in self._TRANSIENT_ATTRS:
            setattr(snapshot, attr, None)
        torch.save(snapshot, path)

    def evaluate(self, D):
        """Score the InD/OoD test loaders with ``D`` and record TPR metrics."""
        print("Computing evaluation statistics...")
        print("> Evaluating InD Wasserstein distances...")
        # detach().cpu(): stored scores must not keep autograd graphs or
        # device memory alive across evaluations, and tpr() converts them
        # to NumPy.
        winv = loader_wass(self.xin_v_loader, D).detach().cpu()
        print("> Evaluating OoD Wasserstein distances...")
        woutv = loader_wass(self.xout_v_loader, D).detach().cpu()
        self.winv.append(winv)
        self.woutv.append(woutv)

        # Test model performance
        tpr_95, tpr_95_thresh = tpr(winv, woutv, 0.95)
        tpr_99, tpr_99_thresh = tpr(winv, woutv, 0.99)
        self.tpr95.append(tpr_95)
        self.tpr95_thresh.append(tpr_95_thresh)
        self.tpr99.append(tpr_99)
        self.tpr99_thresh.append(tpr_99_thresh)

    def display_stats(self):
        """Print summary statistics of the recorded TPRs and thresholds."""
        # Overall stats
        print("\n" + line())
        print("Overall Statistics")
        print_stats(self.tpr95, "TPR@95TNR")
        print_stats(self.tpr95_thresh, "TPR@95TNR-Threshold")
        print_stats(self.tpr99, "TPR@99TNR")
        print_stats(self.tpr99_thresh, "TPR@99TNR-Threshold")
        print("\n" + line())
